Pytorch中DataLoader类的多线程实现方法分析

您所在的位置:网站首页 dataloader workers Pytorch中DataLoader类的多线程实现方法分析

Pytorch中DataLoader类的多线程实现方法分析

2023-03-27 15:47| 来源: 网络整理| 查看: 265

Pytorch中DataLoader类的多线程实现⽅法分析

之前在改⾃定义的DataSet的时候,由于在getitem()⾥⾯写了太多操作,导致训练过程贼慢,于是考虑⽤多线程优化⼀下。查阅⼀些资料发现

pytorch在DataLoader⾥⾯就有多线程的实现,只要在定义的时候将num_worker设置成⼤于0就可以了。遂想要探索⼀下pytorch具体的实现

⽅法。

⾸先找到迭代器:

def __iter__(self):

    return _DataLoaderIter(self)

初始化:

def __init__(self, loader):

    self.dataset = loader.dataset

    self.collate_fn = loader.collate_fn

    self.batch_sampler = loader.batch_sampler

    self.num_workers = loader.num_workers

    self.pin_memory = loader.pin_memory and torch.cuda.is_available()

    self.timeout = loader.timeout

    self.done_event = threading.Event()

    self.sample_iter = iter(self.batch_sampler)

    base_seed = torch.LongTensor(1).random_().item()

collate_fn:将数据整合成⼀个batch返回的⽅法,⽤户可以⾃定义

batch_sampler:⾃定义如何取样

pin_menory:是否将数据集拷贝到显卡上

done_event:事件管理标志

sample_iter:迭代器,所以batch_sampler应该类似于⽤户⾃定义的⼀个数据的列表,⽤来⽣成可迭代对象sample_iter。

下⾯是与多线程有关的⼀些定义:

if self.num_workers > 0:

    self.worker_init_fn = loader.worker_init_fn

    self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]

    self.worker_queue_idx = 0

    self.worker_result_queue = multiprocessing.SimpleQueue()

    self.batches_outstanding = 0

    self.worker_pids_set = False

    self.shutdown = False

    self.send_idx = 0

    self.rcvd_idx = 0

    self.reorder_dict = {}

    self.workers = [

         multiprocessing.Process(

            target=_worker_loop,

            args=(self.dataset, self.index_queues[i],

                  self.worker_result_queue, self.collate_fn, base_seed + i,

                  self.worker_init_fn, i))

            for i in range(self.num_workers)]

worker_init_fn:⽤户定义的每个worker初始化的时候需要执⾏的函数。

index_queues:这⾥⽤到了multiprocessing,pytorch的multiprocessing是对python原⽣的multiprocessing的⼀个封装,不过好像基本

没什么变化。这⾥定义⼀个队列,multiprocessing的Queue类(这个Queue的⽗类)提供了put()和get()⽅法,⽤来向队列中增加线程和移除

线程并返回结果。Pytorch的封装另外提供了send()和recv()⽅法,⽤来接收和读取缓存,具体实现和作⽤这⾥暂且按下不表。通过阅读后⾯的代



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3